import os
import h5py
import numpy as np
import argparse
from IPython import embed
from pytorch3d.loss import chamfer_distance
import torch
from pytorch3d.ops import cubify, sample_points_from_meshes
from scipy.optimize import linear_sum_assignment
import sys
# Make the ShapeGF checkout importable. The location defaults to the original
# hard-coded path and can be overridden through the SHAPEGF_ROOT environment
# variable so the script also works on other machines.
sys.path.insert(0, os.environ.get('SHAPEGF_ROOT', '/home/tiangel/ShapeGF'))
try:
    from evaluation.evaluation_metrics import EMD_CD
    eval_reconstruciton = True
except Exception as err:
    # Best effort: the metrics are optional, so any import-time failure only
    # disables the evaluation. Catching Exception (not a bare except) keeps
    # Ctrl-C and SystemExit working, and the reason is reported on stderr
    # instead of being silently dropped.
    print('Could not import EMD_CD, skipping evaluation: %r' % (err,),
          file=sys.stderr)
    eval_reconstruciton = False

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` cannot be used with argparse: ``bool('False')`` is ``True``
    because any non-empty string is truthy, so the flag could never be
    switched off explicitly.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        True for true/t/yes/y/1 and False for false/f/no/n/0 or the empty
        string (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


parser = argparse.ArgumentParser()

parser.add_argument('--source1', type=str, required=True,
                    help='path to the HDF5 file whose "data" dataset holds '
                         'the point clouds to evaluate')
parser.add_argument('--source2', type=str, required=True,
                    help='path to the HDF5 file whose "data" dataset holds '
                         'the target point clouds')
# nargs='?' with const=True additionally allows a bare "--emd" / "--ori".
parser.add_argument('--emd', type=_str2bool, nargs='?', const=True,
                    default=False,
                    help='whether to also compute the EMD (true/false)')
parser.add_argument('--ori', type=_str2bool, nargs='?', const=True,
                    default=False,
                    help='legacy switch (true/false); only referenced by '
                         'commented-out loading code in this script')

args = parser.parse_args()

# NOTE: legacy loading code kept for reference only. It reads args.category
# and args.save_name, which the argument parser in this file does not define.
# if args.ori:
    # ours_shape_h5 = h5py.File(os.path.join('./shape2prog/output/', args.category, 'shapes.h5'), 'r')
    # ours_shapes = np.array(ours_shape_h5['data'])
# else:
    # ours_shape_h5 = h5py.File(os.path.join('./shape2prog/vqprogram_outputs/', 'test'+args.save_name, 'pred', args.category+'.h5'), 'r')
    # ours_shapes = np.array(ours_shape_h5['shape'])

# Copy the 'data' dataset of each file into memory, then close the file
# right away (the handles used to stay open for the whole run).
with h5py.File(args.source1, 'r') as ours_shape_h5:
    ours_pc = np.array(ours_shape_h5['data'])
with h5py.File(args.source2, 'r') as target_shape_h5:
    target_pc = np.array(target_shape_h5['data'])

bs = 256        # number of shapes handed to EMD_CD per call
cd_dis = []     # one 'MMD-CD' value per batch
emd_dis = []    # one 'MMD-EMD' value per batch
if eval_reconstruciton:
    # Step through the data in strides of `bs`. Unlike int(n / bs) batches,
    # this also covers the trailing partial batch, so no shapes are silently
    # skipped (previously fewer than `bs` shapes meant nothing was evaluated).
    for i, start in enumerate(range(0, ours_pc.shape[0], bs)):
        ours_batch = torch.from_numpy(ours_pc[start:start + bs]).cuda()
        target_batch = torch.from_numpy(target_pc[start:start + bs]).cuda()
        # NOTE(review): the third argument is presumably EMD_CD's internal
        # batch size; confirm against the ShapeGF implementation.
        rec_res = EMD_CD(ours_batch, target_batch, 1)
        print(i, 'CD:', rec_res['MMD-CD'], 'EMD:', rec_res['MMD-EMD'])
        cd_dis.append(rec_res['MMD-CD'])
        emd_dis.append(rec_res['MMD-EMD'])
else:
    print('eval_reconstruciton is False')

# Both branches end the same way: drop into an IPython shell to inspect the
# results, then terminate the script once the shell is closed.
embed()
sys.exit()
# NOTE(review): as the file stands, the unconditional exit() call just above
# means this section never runs. It is kept working so it can be re-enabled.

# Chamfer distance per batch of `bs` shapes (the last batch may be smaller).
# `np.int` was removed from NumPy, so the batch count uses the builtin int.
for i in range(int(np.ceil(ours_pc.shape[0] / bs))):
    # Both inputs are cast to the same dtype; previously only the first one
    # was converted to double while the second stayed float32.
    ours_batch = torch.Tensor(ours_pc[i*bs:(i+1)*bs]).type(torch.double).cuda()
    target_batch = torch.Tensor(target_pc[i*bs:(i+1)*bs]).type(torch.double).cuda()
    # chamfer_distance returns a tuple; element 0 is the value collected here.
    cd_dis.append(chamfer_distance(ours_batch, target_batch)[0])

print('cd_dis:', torch.mean(torch.Tensor(cd_dis)))
if args.emd:
    emd_dis = []
    for i in range(ours_pc.shape[0]):
        print('emd', i)
        q1 = ours_pc[i]
        q2 = target_pc[i]
        # Pairwise squared Euclidean distances via broadcasting:
        # matrix[a, b] = |q1[a] - q2[b]|^2. The point count is taken from the
        # data instead of a hard-coded 2048.
        diff = q1[:, None, :] - q2[None, :, :]
        matrix = np.sum(diff * diff, axis=-1)
        # Optimal one-to-one matching that minimises the summed squared
        # distance.
        row_ind, col_ind = linear_sum_assignment(matrix)
        # Mean (non-squared) Euclidean distance over the matched pairs.
        # Indexing with row_ind keeps this valid for rectangular matrices.
        diff2 = q1[row_ind] - q2[col_ind]
        emd_dis.append(np.mean(np.sqrt(np.sum(diff2 * diff2, axis=1))))
    print('emd_dis:', np.mean(np.array(emd_dis)))

